# -*- coding: utf-8 -*-
"""
Created on Mon Nov  9 13:52:03 2015

@author: ppradeep
"""
import csv
from collections import Counter
import numpy

## Function to calculate metrics
def metrics (t_p, t_n, f_p, f_n):
    """Compute classification metrics from confusion-matrix counts.

    Parameters
    ----------
    t_p, t_n, f_p, f_n : int
        Numbers of true positives, true negatives, false positives and
        false negatives.

    Returns
    -------
    tuple
        (total, acc, sens, spec, ba, kappa) where
        total - number of samples, as a float
        acc   - accuracy in percent, rounded to 2 decimals
        sens  - sensitivity, TP/(TP+FN), in percent, rounded to 2 decimals
        spec  - specificity, TN/(TN+FP), in percent, rounded to 2 decimals
        ba    - balanced accuracy, mean of the rounded sens and spec
        kappa - Cohen's kappa, rounded to 2 decimals

    Any ratio whose denominator is zero (no samples, no positives, no
    negatives, or an expected agreement of exactly 1) is reported as 0.0
    instead of raising ZeroDivisionError.
    """
    def _percent(numerator, denominator):
        # Percentage rounded to two decimals; 0.0 when the ratio is undefined.
        if not denominator:
            return 0.0
        return round(100*float(numerator)/float(denominator), 2)

    total = float(t_p + t_n + f_p + f_n)
    acc = _percent(t_p + t_n, total)
    # Sensitivity is measured over the actual positives (TP + FN) and
    # specificity over the actual negatives (TN + FP).
    sens = _percent(t_p, t_p + f_n)
    spec = _percent(t_n, t_n + f_p)
    ba = round((sens+spec)/2,2)

    if total:
        # Observed agreement and the agreement expected by chance.
        p_o = float(t_p + t_n)/total
        p_e = ((t_p + f_n)/total)*((t_p + f_p)/total) + ((f_p + t_n)/total)*((f_n + t_n)/total)
    else:
        p_o = 0.0
        p_e = 0.0

    chance_gap = 1 - p_e
    if chance_gap:
        kappa = round((p_o - p_e)/chance_gap, 2)
    else:
        # Kappa is undefined when chance agreement is already perfect.
        kappa = 0.0
    return total, acc, sens, spec, ba, kappa
  
# Root directory; every input and output path below is built from it.
path = 'W:/Rapid Tox/'

# Summary output file: open it for writing and emit its header row.
n1 = open(path + 'ReadAcross-Project/Number/RA-HP-LitDataSources-Summary-Dist.csv', 'w')
writeCSV1 = csv.writer(n1)
writeCSV1.writerow([
    'Method',
    'Cut-Off (No. of Analogs)',
    'Cut-Off (Lit sources)',
    'Accuracy',
    'Balanced Accuracy',
    'Kappa Coeff.',
    'Sensitivity',
    'Specificity',
    'True Positives',
    'False Positives',
    'True Negatives',
    'False Negatives',
    'Total',
])

# Detail output file: open it for writing and emit its header row.
n2 = open(path + 'ReadAcross-Project/Number/RA-HP-LitDataSources-Detail-Dist.csv', 'w')
writeCSV2 = csv.writer(n2)
writeCSV2.writerow([
    'Hindered Phenol (CERAPP ID)',
    'True Binding',
    'Number of Lit. Sources',
    'Method',
    'Number of Analogs',
    'Analogs',
    'Analog Prediction',
])

####
# Read-across evaluation, repeated for a literature-source cut-off of 1 to 10.
# For each cut-off:
#   1. Read the CERAPP phenol file and build
#        p_data  : phenol id -> ['HP' | 'NHP', binding class]
#        hp_data : the same, restricted to the phenols labelled 'HP'
#      keeping only phenols with at least `lit_c` literature sources.
#   2. From each of the three similarity files (MoSS, PubChem, Chemotyper)
#      collect, for every HP, up to `n_analogs` analogs from p_data.
#   3. For 1 to 10 analogs, predict each HP's binding class by majority vote
#      of its analogs (per method, and for the three methods pooled), write
#      one detail row per prediction and one summary row per method.
####

def _majority_vote(labels):
    """Return the most common label; a tie for first place resolves to '1'.

    `labels` must be non-empty.
    """
    ranked = Counter(labels).most_common()
    if len(ranked) > 1 and ranked[0][1] == ranked[1][1]:
        # The class labels are the strings '0'/'1', so the tie-break value
        # has to be the string '1' to be counted by _tally.
        return '1'
    return ranked[0][0]


def _new_counts():
    """Return an empty confusion-matrix tally."""
    return {'tp': 0, 'fp': 0, 'tn': 0, 'fn': 0}


def _tally(counts, truth, prediction):
    """Add one (truth, prediction) pair to `counts`; other labels are ignored."""
    if truth == '1' and prediction == '1':
        counts['tp'] += 1
    elif truth == '1' and prediction == '0':
        counts['fn'] += 1
    elif truth == '0' and prediction == '0':
        counts['tn'] += 1
    elif truth == '0' and prediction == '1':
        counts['fp'] += 1


def _metrics_for(counts):
    """Run `metrics` on a tally produced by _new_counts/_tally."""
    return metrics(counts['tp'], counts['tn'], counts['fp'], counts['fn'])


def _summary_row(method, threshold, lit_c, counts, result):
    """Build one summary CSV row from a method's own tally and metrics."""
    return [method, threshold, lit_c, result[1], result[4], result[5], result[2], result[3],
            counts['tp'], counts['fp'], counts['tn'], counts['fn'], result[0]]


def _top_analogs(key, distances, column_ids, candidates, n_max, cutoff):
    """Return up to `n_max` [analog id, value] pairs for chemical `key`.

    Columns are visited in descending order of their value and kept when the
    column id differs from `key`, is present in `candidates`, and the value is
    at least `cutoff` (i.e. the values are treated as similarities).
    """
    found = []
    # NOTE(review): the values are still strings here, so argsort orders them
    # as text rather than numerically; that matches numeric order only for
    # uniformly formatted decimals - confirm against the input files.
    for index in numpy.argsort(distances)[::-1]:
        n = column_ids[index]
        if key != n and n in candidates and float(distances[index]) >= cutoff:
            found.append([n, distances[index]])
            if len(found) >= n_max:
                break
    return found


def _neighbors_from_table(filename, first_col, hp_ids, candidates, n_max, cutoff):
    """Collect analogs for every HP row of a full similarity table.

    The header row (from column `first_col` on) supplies the column ids; the
    first cell of each data row is the row's chemical id.
    """
    neighbors = {}
    with open(filename, 'r') as handle:
        reader = csv.reader(handle, delimiter=',')
        column_ids = next(reader)[first_col:]
        for line in reader:
            key = line[0]
            if key not in hp_ids:
                continue
            found = _top_analogs(key, line[first_col:], column_ids, candidates, n_max, cutoff)
            if found:
                neighbors.setdefault(key, []).extend(found)
    return neighbors


try:
    for lit_c in range(1, 11):
        all_p = []; all_hp = []
        p_data = {}; hp_data = {}

        with open(path+'ReadAcross-Project/CERAPP_Phenols.csv','r') as f0:
            readCSV0 = csv.reader(f0, delimiter=',')
            next(readCSV0) # Skip the header row
            for line in readCSV0:
                phenol_id = line[0]; all_p.append(phenol_id)
                class_b = line[10] # Binder or not. 0 or 1.
                if int(line[9]) >= lit_c: # At least lit_c literature sources
                    n_H = line[7] #Number of hindered phenolic groups in the chemical
                    n_NH = line[8] #Number of non-hindered phenolic groups in the chemical
                    if n_H == '0' and n_NH != '0':
                        p_data[phenol_id] = ['NHP', class_b] # Not HP
                    else:
                        # NOTE(review): a row with n_H == '0' and n_NH == '0'
                        # also ends up here and is labelled HP.
                        p_data[phenol_id] = ['HP', class_b] # HP
                        all_hp.append(phenol_id)
                        hp_data[phenol_id] = ['HP', class_b]

        ## Select n analogs from each method
        n_analogs = 10
        min_similarity = 0.70

        ## ****** 1. MoSS MCSS *****
        dist_mat = []
        cerapp_id_m = [] #List of all ids in MOSS distance file
        with open(path+'ReadAcross-Project/DistanceFiles/Cerapp-ReadAcross-MoSS-Ds.csv','r') as f1:
            readCSV1 = csv.reader(f1, delimiter=',')
            next(readCSV1) # Header row is not used; ids come from column 0
            for line in readCSV1:
                cerapp_id_m.append(line[0])
                # Need one more extra column to compensate for the blank entry
                # for each chemical by itself.
                dist_mat.append(line[5:] + [''])

        ## Generate full distance matrix
        # full_dist_mat is the same object as dist_mat: the upper triangle is
        # filled in place from the lower triangle, which is never overwritten.
        size = len(dist_mat)
        full_dist_mat = dist_mat
        for i in range(size):
            full_dist_mat[i][i] = '0.0'
            for j in range(i+1, size):
                full_dist_mat[i][j] = dist_mat[j][i]

        #Create a dictionary of each hindered phenol and its closest analogs
        neighbors_m = {}
        for key, distances in zip(cerapp_id_m, full_dist_mat):
            if key not in hp_data:
                continue
            found = _top_analogs(key, distances, cerapp_id_m, p_data, n_analogs, min_similarity)
            if found:
                neighbors_m.setdefault(key, []).extend(found)

        ## ***** 2. Pubchem *****
        neighbors_p = _neighbors_from_table(
            path+'ReadAcross-Project/DistanceFiles/Cerapp-ReadAcross-Pubchem-TDs.csv',
            5, hp_data, p_data, n_analogs, min_similarity)

        ## ***** 3. Chemotyper *****
        neighbors_c = _neighbors_from_table(
            path+'ReadAcross-Project/DistanceFiles/Cerapp-ReadAcross-Chemotyper-TDs.csv',
            1, hp_data, p_data, n_analogs, min_similarity)

        ## Calculate the prediction metrics using read across from all three methods. Prediction for chemical = majority vote of analogs.
        ## Also calculate prediction metrics when the n analogs from each method are combined. Eg. 1 analog each from P, C and M.
        ## Prediction is the majority vote again.
        method_neighbors = [('MoSS', neighbors_m), ('PubChem', neighbors_p), ('Chemotyper', neighbors_c)]
        acc_m = []; acc_p = []; acc_c = []; acc_all = []
        acc_by_method = {'MoSS': acc_m, 'PubChem': acc_p, 'Chemotyper': acc_c}

        thresholds = range(1, 11)
        for threshold in thresholds:
            counts = {'MoSS': _new_counts(), 'PubChem': _new_counts(), 'Chemotyper': _new_counts()}
            counts_all = _new_counts()

            for hp in all_hp:
                truth = hp_data[hp][1]
                neighbors_all = []
                for name, neighbors in method_neighbors:
                    analogs = [x[0] for x in neighbors.get(hp, [])][0:threshold]
                    if not analogs:
                        # This method found no analog for hp: no prediction.
                        continue
                    prediction = _majority_vote([p_data[a][1] for a in analogs])
                    _tally(counts[name], truth, prediction)
                    writeCSV2.writerow([hp, truth, lit_c, \
                        name, threshold, analogs, prediction])
                    neighbors_all = neighbors_all + analogs

                # Prediction from the pooled, de-duplicated M/P/C analogs
                neighbors_unique = list(set(neighbors_all))
                if neighbors_unique:
                    prediction = _majority_vote([p_data[a][1] for a in neighbors_unique])
                    _tally(counts_all, truth, prediction)
                    writeCSV2.writerow([hp, truth, lit_c, \
                        'M & C & P', threshold, neighbors_unique, prediction])

            # Each method's row reports that method's own confusion counts.
            results = [(name, _metrics_for(counts[name])) for name, _ in method_neighbors]
            for name, result in results:
                writeCSV1.writerow(_summary_row(name, threshold, lit_c, counts[name], result))

            metrics_all = _metrics_for(counts_all)
            writeCSV1.writerow(_summary_row('MoSS & PubChem & Chemotyper', threshold, lit_c, counts_all, metrics_all))

            for name, result in results:
                acc_by_method[name].append(result[1])
            acc_all.append(metrics_all[1])
finally:
    # Close the output files even if a run fails part-way, so that whatever
    # was written so far is flushed to disk.
    n1.close()
    n2.close()

#%%
# Bar plot to demonstrate variation in predictive power with data quality for PubChem
import csv
import numpy as np
from matplotlib import pyplot as plt

# Function to add labels on each bar
def autolabel(rects, label, f):
    """Draw `label` just above every bar in `rects`.

    The text is placed on the module-level axes `ax`, horizontally at the
    bar's centre shifted by `f` and vertically 0.3 above the bar's height.
    """
    text_style = dict(ha='center', va='bottom', fontsize = 14)
    for bar in rects:
        bar_top = bar.get_height()
        x_pos = bar.get_x() + bar.get_width()/2. + f
        ax.text(x_pos, bar_top + 0.3, label, **text_style)
                
# Read the summary file back and group its rows by literature-source cut-off
# (column 2). Each stored row is [method, n analogs, accuracy, balanced
# accuracy, total], all still as strings.
# NOTE(review): rows of every method are pooled here; nothing restricts the
# plot to a single method - confirm that this is intended.
data = {}
with open(path+'ReadAcross-Project/Number/RA-HP-LitDataSources-Summary-Dist.csv','r') as f1:
    readCSV1 = csv.reader(f1, delimiter=',')
    next(readCSV1) # Skip the header row
    for line in readCSV1:
        if not line:
            # Blank rows can appear depending on how the file was written.
            continue
        data.setdefault(line[2],[]).append([line[0], line[1], line[3], line[4], line[12]])


fig = plt.figure(figsize=(16,8), dpi = 100)
ax = fig.add_subplot(111)

for x in data:
    rows = data[x]
    # Compare as numbers: comparing the raw strings would rank '95.2' above
    # '100.0'.
    y1 = max(float(row[2]) for row in rows) # Maximum accuracy
    best = max(rows, key=lambda row: float(row[3])) # First row with the maximum balanced accuracy
    y2 = float(best[3]) # Maximum balanced accuracy
    rects1 = ax.bar(float(x)-0.1, y1, 0.2, color='lightblue', align='center')
    rects2 = ax.bar(float(x)+0.1, y2, 0.2, color='darkblue', align='center')

    # Label the balanced-accuracy bar with the number of analogs (N) and the
    # total number of predictions (T) of the row that achieved it.
    autolabel(rects2, 'N = %d \n T = %d' %(float(best[1]),float(best[4])), 0.24)

ax.set_xlim(0.2,10.7)
ax.set_ylim(60,100)
ax.set_ylabel('Percentage (%)', fontsize = 22)
ax.set_xlabel('Number of Literature Data Sources (k)', fontsize = 22)

# Raw strings keep the LaTeX backslash without an invalid-escape warning.
x_ticks = [r'$\geq%d$' % k for k in range(1, 11)]
plt.xticks(range(1,11,1), x_ticks)
plt.yticks(range(60,105,5))
ax.legend( (rects1[0], rects2[0]), ('Accuracy', 'Balanced Accuracy'), fontsize = 20, loc='upper left')
ax.tick_params(axis='x', labelsize=22)

ax.xaxis.labelpad = 15
ax.yaxis.labelpad = 15

# Save before showing: once a blocking show() returns, the figure may already
# have been closed and savefig would write an empty image.
plt.savefig(path+'ReadAcross-Project/Number/RAPredictions_LitDataSources_HP-Dist.png')
plt.show()
